{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ISubpr_SSsiM"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "3jTMb1dySr3V"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6DWfyNThSziV"
      },
      "source": [
        "# 模块、层和模型简介\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/intro_to_modules\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/intro_to_modules.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/intro_to_modules.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 Github 上查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/intro_to_modules.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v0DdlfacAdTZ"
      },
      "source": [
        "要进行 TensorFlow 机器学习，您可能需要定义、保存和恢复模型。\n",
        "\n",
        "抽象地说，模型是：\n",
        "\n",
        "- 一个在张量上进行某些计算的函数（**前向传递**）\n",
        "- 一些可以更新以响应训练的变量\n",
        "\n",
        "在本指南中，您将深入学习 Keras，了解如何定义 TensorFlow 模型。本文着眼于 TensorFlow 如何收集变量和模型，以及如何保存和恢复它们。\n",
        "\n",
        "注：如果您想立即开始使用 Keras，请参阅 [Keras 指南集合](./keras/)。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VSa6ayJmfZxZ"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "goZwOXp_xyQj"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from datetime import datetime\n",
        "\n",
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yt5HEbsYAbw1"
      },
      "source": [
        "## 在 TensorFlow 中定义模型和层\n",
        "\n",
        "大多数模型都由层组成。层是具有已知数学结构的函数，可以重复使用并且具有可训练的变量。在 TensorFlow 中，层和模型的大多数高级实现（例如 Keras 或 Sonnet）都在以下同一个基础类上构建：`tf.Module`。\n",
        "\n",
        "下面是一个在标量张量上运行的非常简单的 `tf.Module` 示例：\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "alhYPVEtAiSy"
      },
      "outputs": [],
      "source": [
        "class SimpleModule(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.a_variable = tf.Variable(5.0, name=\"train_me\")\n",
        "    self.non_trainable_variable = tf.Variable(5.0, trainable=False, name=\"do_not_train_me\")\n",
        "  def __call__(self, x):\n",
        "    return self.a_variable * x + self.non_trainable_variable\n",
        "\n",
        "simple_module = SimpleModule(name=\"simple\")\n",
        "\n",
        "simple_module(tf.constant(5.0))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JwMc_zu5Ant8"
      },
      "source": [
        "模块和引申而来的层是“对象”的深度学习术语：它们具有内部状态以及使用该状态的方法。\n",
        "\n",
        "`__call__` 并无特殊之处，只是其行为与 [Python 可调用对象](https://stackoverflow.com/questions/111234/what-is-a-callable)类似；您可以使用任何函数来调用模型。\n",
        "\n",
        "您可以出于任何原因开启和关闭变量的可训练性，包括在微调过程中冻结层和变量。\n",
        "\n",
        "注：tf.Module 是 tf.keras.layers.Layer 和 tf.keras.Model 的基类，因此您在此处看到的一切内容也适用于 Keras。出于历史兼容性的原因，Keras 层不会从模块收集变量，因此您的模型应仅使用模块或仅使用 Keras 层。不过，下面给出的用于检查变量的方法相同在这两种情况下相同。\n",
        "\n",
        "通过将 `tf.Module` 子类化，将自动收集分配给该对象属性的任何 `tf.Variable` 或 `tf.Module` 实例。这样，您可以保存和加载变量，还可以创建 `tf.Module` 的集合。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CyzYy4A_CbVf"
      },
      "outputs": [],
      "source": [
        "# All trainable variables\n",
        "print(\"trainable variables:\", simple_module.trainable_variables)\n",
        "# Every variable\n",
        "print(\"all variables:\", simple_module.variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nuSFrRUNCaaW"
      },
      "source": [
        "下面是一个由模块组成的两层线性层模型的示例。\n",
        "\n",
        "首先是一个密集（线性）层："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Efb2p2bzAn-V"
      },
      "outputs": [],
      "source": [
        "class Dense(tf.Module):\n",
        "  def __init__(self, in_features, out_features, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.w = tf.Variable(\n",
        "      tf.random.normal([in_features, out_features]), name='w')\n",
        "    self.b = tf.Variable(tf.zeros([out_features]), name='b')\n",
        "  def __call__(self, x):\n",
        "    y = tf.matmul(x, self.w) + self.b\n",
        "    return tf.nn.relu(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bAhMuC-UpnhX"
      },
      "source": [
        "随后是完整的模型，此模型将创建并应用两个层实例。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QQ7qQf-DFw74"
      },
      "outputs": [],
      "source": [
        "class SequentialModule(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "\n",
        "    self.dense_1 = Dense(in_features=3, out_features=3)\n",
        "    self.dense_2 = Dense(in_features=3, out_features=2)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    x = self.dense_1(x)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "# You have made a model!\n",
        "my_model = SequentialModule(name=\"the_model\")\n",
        "\n",
        "# Call it, with random results\n",
        "print(\"Model results:\", my_model(tf.constant([[2.0, 2.0, 2.0]])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d1oUzasJHHXf"
      },
      "source": [
        "`tf.Module` 实例将以递归方式自动收集分配给它的任何 `tf.Variable` 或 `tf.Module` 实例。这样，您可以使用单个模型实例管理 `tf.Module` 的集合，并保存和加载整个模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JLFA5_PEGb6C"
      },
      "outputs": [],
      "source": [
        "print(\"Submodules:\", my_model.submodules)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6lzoB8pcRN12"
      },
      "outputs": [],
      "source": [
        "for var in my_model.variables:\n",
        "  print(var, \"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hoaxL3zzm0vK"
      },
      "source": [
        "### 等待创建变量\n",
        "\n",
        "您在这里可能已经注意到，必须定义层的输入和输出大小。这样，`w` 变量才会具有已知的形状并且可被分配。\n",
        "\n",
        "通过将变量创建推迟到第一次使用特定输入形状调用模块时，您将无需预先指定输入大小。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XsGCLFXlnPum"
      },
      "outputs": [],
      "source": [
        "class FlexibleDenseModule(tf.Module):\n",
        "  # Note: No need for `in+features`\n",
        "  def __init__(self, out_features, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.is_built = False\n",
        "    self.out_features = out_features\n",
        "\n",
        "  def __call__(self, x):\n",
        "    # Create variables on first call.\n",
        "    if not self.is_built:\n",
        "      self.w = tf.Variable(\n",
        "        tf.random.normal([x.shape[-1], self.out_features]), name='w')\n",
        "      self.b = tf.Variable(tf.zeros([self.out_features]), name='b')\n",
        "      self.is_built = True\n",
        "\n",
        "    y = tf.matmul(x, self.w) + self.b\n",
        "    return tf.nn.relu(y)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8bjOWax9LOkP"
      },
      "outputs": [],
      "source": [
        "# Used in a module\n",
        "class MySequentialModule(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "\n",
        "    self.dense_1 = FlexibleDenseModule(out_features=3)\n",
        "    self.dense_2 = FlexibleDenseModule(out_features=2)\n",
        "\n",
        "  def __call__(self, x):\n",
        "    x = self.dense_1(x)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "my_model = MySequentialModule(name=\"the_model\")\n",
        "print(\"Model results:\", my_model(tf.constant([[2.0, 2.0, 2.0]])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "49JfbhVrpOLH"
      },
      "source": [
        "这种灵活性是 TensorFlow 层通常仅需要指定其输出的形状（例如在 `tf.keras.layers.Dense` 中），而无需指定输入和输出大小的原因。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JOLVVBT8J_dl"
      },
      "source": [
        "## 保存权重\n",
        "\n",
        "您可以将 `tf.Module` 保存为[检查点](./checkpoint.ipynb)和 [SavedModel](./saved_model.ipynb)。\n",
        "\n",
        "检查点即是权重（即模块及其子模块内部的变量集的值）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pHXKRDk7OLHA"
      },
      "outputs": [],
      "source": [
        "chkp_path = \"my_checkpoint\"\n",
        "checkpoint = tf.train.Checkpoint(model=my_model)\n",
        "checkpoint.write(chkp_path)\n",
        "checkpoint.write(chkp_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WXOPMBR4T4ZR"
      },
      "source": [
        "检查点由两种文件组成---数据本身以及元数据的索引文件。索引文件跟踪实际保存的内容和检查点的编号，而检查点数据包含变量值及其特性查找路径。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jBV3fprlTWqJ"
      },
      "outputs": [],
      "source": [
        "!ls my_checkpoint*"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CowCuBTvXgUu"
      },
      "source": [
        "您可以查看检查点内部，以确保整个变量集合已由包含这些变量的 Python 对象保存并排序。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o2QAdfpvS8tB"
      },
      "outputs": [],
      "source": [
        "tf.train.list_variables(chkp_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4eGaNiQWcK4j"
      },
      "source": [
        "在分布式（多机）训练期间，可以将它们分片，这就是要对它们进行编号（例如 '00000-of-00001'）的原因。不过，在本例中，只有一个分片。\n",
        "\n",
        "重新加载模型时，将重写 Python 对象中的值。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UV8rdDzcwVVg"
      },
      "outputs": [],
      "source": [
        "new_model = MySequentialModule()\n",
        "new_checkpoint = tf.train.Checkpoint(model=new_model)\n",
        "new_checkpoint.restore(\"my_checkpoint\")\n",
        "\n",
        "# Should be the same result as above\n",
        "new_model(tf.constant([[2.0, 2.0, 2.0]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BnPwDRwamdfq"
      },
      "source": [
        "注：由于检查点处于长时间训练工作流的核心位置，因此 `tf.checkpoint.CheckpointManager` 是一个可使检查点管理变得更简单的辅助类。有关更多详细信息，请参阅[指南](./checkpoint.ipynb)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pSZebVuWxDXu"
      },
      "source": [
        "## 保存函数\n",
        "\n",
        "TensorFlow 可以在不使用原始 Python 对象的情况下运行模型，如 [TensorFlow Serving](https://tensorflow.org/tfx) 和 [TensorFlow Lite](https://tensorflow.org/lite) 中所见，甚至当您从 [TensorFlow Hub](https://tensorflow.org/hub) 下载经过训练的模型时也是如此。\n",
        "\n",
        "TensorFlow 需要了解如何执行 Python 中描述的计算，但**不需要原始代码**。为此，您可以创建一个**计算图**，如上一篇[指南](./intro_to_graphs.ipynb)中所述。\n",
        "\n",
        "此计算图中包含实现函数的*运算*。\n",
        "\n",
        "您可以通过添加 `@tf.function` 装饰器在上面的模型中定义计算图，以指示此代码应作为计算图运行。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WQTvkapUh7lk"
      },
      "outputs": [],
      "source": [
        "class MySequentialModule(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "\n",
        "    self.dense_1 = Dense(in_features=3, out_features=3)\n",
        "    self.dense_2 = Dense(in_features=3, out_features=2)\n",
        "\n",
        "  @tf.function\n",
        "  def __call__(self, x):\n",
        "    x = self.dense_1(x)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "# You have made a model with a graph!\n",
        "my_model = MySequentialModule(name=\"the_model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hW66YXBziLo9"
      },
      "source": [
        "您创建的模块的工作原理与之前完全相同。传递给函数的每个唯一签名都会创建一个单独的计算图。有关详细信息，请参阅[计算图指南](./intro_to_graphs.ipynb)。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H5zUfti3iR52"
      },
      "outputs": [],
      "source": [
        "print(my_model([[2.0, 2.0, 2.0]]))\n",
        "print(my_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lbGlU1kgyDo7"
      },
      "source": [
        "您可以通过在 TensorBoard 摘要中跟踪计算图来将其可视化。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zmy-T67zhp-S"
      },
      "outputs": [],
      "source": [
        "# Set up logging.\n",
        "stamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
        "logdir = \"logs/func/%s\" % stamp\n",
        "writer = tf.summary.create_file_writer(logdir)\n",
        "\n",
        "# Create a new model to get a fresh trace\n",
        "# Otherwise the summary will not see the graph.\n",
        "new_model = MySequentialModule()\n",
        "\n",
        "# Bracket the function call with\n",
        "# tf.summary.trace_on() and tf.summary.trace_export().\n",
        "tf.summary.trace_on(graph=True, profiler=True)\n",
        "# Call only one tf.function when tracing.\n",
        "z = print(new_model(tf.constant([[2.0, 2.0, 2.0]])))\n",
        "with writer.as_default():\n",
        "  tf.summary.trace_export(\n",
        "      name=\"my_func_trace\",\n",
        "      step=0,\n",
        "      profiler_outdir=logdir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gz4lwNZ9hR79"
      },
      "source": [
        "启动 Tensorboard 以查看生成的跟踪："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V4MXDbgBnkJu"
      },
      "outputs": [],
      "source": [
        "#docs_infra: no_execute\n",
        "%tensorboard --logdir logs/func"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gjattu0AhYUl"
      },
      "source": [
        "![A screenshot of the graph, in tensorboard](https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/images/tensorboard_graph.png?raw=true)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SQu3TVZecmL7"
      },
      "source": [
        "### 创建 `SavedModel`\n",
        "\n",
        "共享经过完全训练的模型的推荐方式是使用 `SavedModel`。`SavedModel` 包含函数集合与权重集合。\n",
        "\n",
        "您可以保存刚刚创建的模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Awv_Tw__WK7a"
      },
      "outputs": [],
      "source": [
        "tf.saved_model.save(my_model, \"the_saved_model\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SXv3mEKsefGj"
      },
      "outputs": [],
      "source": [
        "# Inspect the in the directory\n",
        "!ls -l the_saved_model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vQQ3hEvHYdoR"
      },
      "outputs": [],
      "source": [
        "# The variables/ directory contains a checkpoint of the variables \n",
        "!ls -l the_saved_model/variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xBqPop7ZesBU"
      },
      "source": [
        "`saved_model.pb` 文件是一个描述函数式 `tf.Graph` 的[协议缓冲区](https://developers.google.com/protocol-buffers)。\n",
        "\n",
        "可以从此表示加载模型和层，而无需实际构建创建该表示的类的实例。在您没有（或不需要）Python 解释器（例如大规模应用或在边缘设备上），或者在原始 Python 代码不可用或不实用的情况下，这样做十分理想。\n",
        "\n",
        "您可以将模型作为新对象加载："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zRFcA5wIefv4"
      },
      "outputs": [],
      "source": [
        "new_model = tf.saved_model.load(\"the_saved_model\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-9EF3mT7i3qN"
      },
      "source": [
        "通过加载已保存模型创建的 `new_model` 是 TensorFlow 内部的用户对象，无需任何类知识。它不是 `SequentialModule` 类型的对象。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EC_eQj7yi54G"
      },
      "outputs": [],
      "source": [
        "isinstance(new_model, SequentialModule)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-OrOX1zxiyhR"
      },
      "source": [
        "此新模型​​适用于已定义的输入签名。您不能向以这种方式恢复的模型添加更多签名。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_23BYYBWfKnc"
      },
      "outputs": [],
      "source": [
        "print(my_model([[2.0, 2.0, 2.0]]))\n",
        "print(my_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qSFhoMtTjSR6"
      },
      "source": [
        "因此，利用 `SavedModel`，您可以使用 `tf.Module` 保存 TensorFlow 权重和计算图，随后再次加载它们。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rb9IdN7hlUZK"
      },
      "source": [
        "## Keras 模型和层\n",
        "\n",
        "请注意，到目前为止，还没有提到 Keras。您可以在 `tf.Module` 上构建自己的高级 API，而我们已经拥有这些 API。\n",
        "\n",
        "在本部分中，您将研究 Keras 如何使用 `tf.Module`。可在 [Keras 指南](keras/sequential_model.ipynb)中找到有关 Keras 模型的完整用户指南。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uigsVGPreE-D"
      },
      "source": [
        "### Keras 层\n",
        "\n",
        "`tf.keras.layers.Layer` 是所有 Keras 层的基类，它继承自 `tf.Module`。\n",
        "\n",
        "您只需换出父项，然后将 `__call__` 更改为 `call` 即可将模块转换为 Keras 层："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "88YOGquhnQRd"
      },
      "outputs": [],
      "source": [
        "class MyDense(tf.keras.layers.Layer):\n",
        "  # Adding **kwargs to support base Keras layer arguemnts\n",
        "  def __init__(self, in_features, out_features, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "\n",
        "    # This will soon move to the build step; see below\n",
        "    self.w = tf.Variable(\n",
        "      tf.random.normal([in_features, out_features]), name='w')\n",
        "    self.b = tf.Variable(tf.zeros([out_features]), name='b')\n",
        "  def call(self, x):\n",
        "    y = tf.matmul(x, self.w) + self.b\n",
        "    return tf.nn.relu(y)\n",
        "\n",
        "simple_layer = MyDense(name=\"simple\", in_features=3, out_features=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nYGmAsPrws--"
      },
      "source": [
        "Keras 层有自己的 `__call__`，它会进行下一部分中所述的某些簿记，然后调用 `call()`。您应当不会看到功能上的任何变化。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nIqE8wOznYKG"
      },
      "outputs": [],
      "source": [
        "simple_layer([[2.0, 2.0, 2.0]])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tmN5vb1K18U1"
      },
      "source": [
        "### `build` 步骤\n",
        "\n",
        "如上所述，在您确定输入形状之前，等待创建变量在许多情况下十分方便。\n",
        "\n",
        "Keras 层具有额外的生命周期步骤，可让您在定义层时获得更高的灵活性。这是在 `build()` 函数中定义的。\n",
        "\n",
        "`build` 仅被调用一次，而且是使用输入的形状调用的。它通常用于创建变量（权重）。\n",
        "\n",
        "您可以根据输入的大小灵活地重写上面的 `MyDense` 层。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4YTfrlgdsURp"
      },
      "outputs": [],
      "source": [
        "class FlexibleDense(tf.keras.layers.Layer):\n",
        "  # Note the added `**kwargs`, as Keras supports many arguments\n",
        "  def __init__(self, out_features, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "    self.out_features = out_features\n",
        "\n",
        "  def build(self, input_shape):  # Create the state of the layer (weights)\n",
        "    self.w = tf.Variable(\n",
        "      tf.random.normal([input_shape[-1], self.out_features]), name='w')\n",
        "    self.b = tf.Variable(tf.zeros([self.out_features]), name='b')\n",
        "\n",
        "  def call(self, inputs):  # Defines the computation from inputs to outputs\n",
        "    return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "# Create the instance of the layer\n",
        "flexible_dense = FlexibleDense(out_features=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Koc_uSqt2PRh"
      },
      "source": [
        "此时，模型尚未构建，因此没有变量。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DgyTyUD32Ln4"
      },
      "outputs": [],
      "source": [
        "flexible_dense.variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-KdamIVl2W8Y"
      },
      "source": [
        "调用该函数会分配大小适当的变量。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IkLyEx7uAoTK"
      },
      "outputs": [],
      "source": [
        "# Call it, with predictably random results\n",
        "print(\"Model results:\", flexible_dense(tf.constant([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Swofpkrd2YDd"
      },
      "outputs": [],
      "source": [
        "flexible_dense.variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7PuNUnf0OIpF"
      },
      "source": [
        "由于仅调用一次 `build`，因此如果输入形状与层的变量不兼容，输入将被拒绝。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "caYWDrHSAy_j"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  print(\"Model results:\", flexible_dense(tf.constant([[2.0, 2.0, 2.0, 2.0]])))\n",
        "except tf.errors.InvalidArgumentError as e:\n",
        "  print(\"Failed:\", e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YnporXiudF1I"
      },
      "source": [
        "Keras 层具有许多额外的功能，包括：\n",
        "\n",
        "- 可选损失\n",
        "- 对指标的支持\n",
        "- 对可选 `training` 参数的内置支持，用于区分训练和推断用途\n",
        "- `get_config` 和 `from_config` 方法，允许您准确存储配置以在 Python 中克隆模型\n",
        "\n",
        "在自定义层的[完整指南](./keras/custom_layers_and_models.ipynb)中阅读关于它们的信息。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L2kds2IHw2KD"
      },
      "source": [
        "### Keras 模型\n",
        "\n",
        "您可以将模型定义为嵌套的 Keras 层。\n",
        "\n",
        "但是，Keras 还提供了称为 `tf.keras.Model` 的全功能模型类。它继承自 `tf.keras.layers.Layer`，因此 Keras 模型是一种 Keras 层，支持以同样的方式使用、嵌套和保存。Keras 模型还具有额外的功能，这使它们可以轻松训练、评估、加载、保存，甚至在多台机器上进行训练。\n",
        "\n",
        "您可以使用几乎相同的代码定义上面的 `SequentialModule`，再次将 `__call__` 转换为 `call()` 并更改父项。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hqjo1DiyrHrn"
      },
      "outputs": [],
      "source": [
        "class MySequentialModel(tf.keras.Model):\n",
        "  def __init__(self, name=None, **kwargs):\n",
        "    super().__init__(**kwargs)\n",
        "\n",
        "    self.dense_1 = FlexibleDense(out_features=3)\n",
        "    self.dense_2 = FlexibleDense(out_features=2)\n",
        "  def call(self, x):\n",
        "    x = self.dense_1(x)\n",
        "    return self.dense_2(x)\n",
        "\n",
        "# You have made a Keras model!\n",
        "my_sequential_model = MySequentialModel(name=\"the_model\")\n",
        "\n",
        "# Call it on a tensor, with random results\n",
        "print(\"Model results:\", my_sequential_model(tf.constant([[2.0, 2.0, 2.0]])))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8i-CR_h2xw3z"
      },
      "source": [
        "所有相同的功能都可用，包括跟踪变量和子模块。\n",
        "\n",
        "注：为了强调上面的注意事项，嵌套在 Keras 层或模型中的原始 `tf.Module` 将不会收集其变量以用于训练或保存。相反，它会在 Keras 层内嵌套 Keras 层。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hdLQFNdMsOz1"
      },
      "outputs": [],
      "source": [
        "my_sequential_model.variables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JjVAMrAJsQ7G"
      },
      "outputs": [],
      "source": [
        "my_sequential_model.submodules"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FhP8EItC4oac"
      },
      "source": [
        "重写 `tf.keras.Model` 是一种构建 TensorFlow 模型的极 Python 化方式。如果要从其他框架迁移模型，这可能非常简单。\n",
        "\n",
        "如果要构造的模型是现有层和输入的简单组合，则可以使用[函数式 API](./keras/functional.ipynb) 节省时间和空间，此 API 附带有关模型重构和架构的附加功能。\n",
        "\n",
        "下面是使用函数式 API 构造的相同模型："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jJiZZiJ0fyqQ"
      },
      "outputs": [],
      "source": [
        "inputs = tf.keras.Input(shape=[3,])\n",
        "\n",
        "x = FlexibleDense(3)(inputs)\n",
        "x = FlexibleDense(2)(x)\n",
        "\n",
        "my_functional_model = tf.keras.Model(inputs=inputs, outputs=x)\n",
        "\n",
        "my_functional_model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kg-xAZw5gaG6"
      },
      "outputs": [],
      "source": [
        "my_functional_model(tf.constant([[2.0, 2.0, 2.0]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s_BK9XH5q9cq"
      },
      "source": [
        "这里的主要区别在于，输入形状是作为函数构造过程的一部分预先指定的。在这种情况下，不必完全指定 `input_shape` 参数；您可以将某些维度保留为 `None`。\n",
        "\n",
        "注：您无需在子类化模型中指定 `input_shape` 或 `InputLayer`；这些参数和层将被忽略。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qI9aXLnaHEFF"
      },
      "source": [
        "## 保存 Keras 模型\n",
        "\n",
        "可以为 Keras 模型创建检查点，这看起来和 `tf.Module` 一样。\n",
        "\n",
        "Keras 模型也可以使用 `tf.saved_models.save()` 保存，因为它们是模块。但是，Keras 模型具有更方便的方法和其他功能。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SAz-KVZlzAJu"
      },
      "outputs": [],
      "source": [
        "my_sequential_model.save(\"exname_of_file\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C2urAeR-omns"
      },
      "source": [
        "同样地，它们也可以轻松重新加载。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wj5DW-LCopry"
      },
      "outputs": [],
      "source": [
        "reconstructed_model = tf.keras.models.load_model(\"exname_of_file\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EA7P_MNvpviZ"
      },
      "source": [
        "Keras `SavedModels` 还可以保存指标、损失和优化器状态。\n",
        "\n",
        "可以使用此重构模型，并且在相同数据上调用时会产生相同的结果。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P_wGfQo5pe6T"
      },
      "outputs": [],
      "source": [
        "reconstructed_model(tf.constant([[2.0, 2.0, 2.0]]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xKyjlkceqjwD"
      },
      "source": [
        "有关保存和序列化 Keras 模型，包括为自定义层提供配置方法来为功能提供支持的更多信息，请参阅[保存和序列化指南](keras/save_and_serialize)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kcdMMPYv7Krz"
      },
      "source": [
        "# 后续步骤\n",
        "\n",
        "如果您想了解有关 Keras 的更多详细信息，可以在[此处](./keras/)查看现有的 Keras 指南。\n",
        "\n",
        "在 `tf.module` 上构建的高级 API 的另一个示例是 DeepMind 的 Sonnet，[其网站](https://github.com/deepmind/sonnet)上有详细介绍。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "ISubpr_SSsiM"
      ],
      "name": "intro_to_modules.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
